import logging
import time

from experiments.utils import load_dataset_safely, seed_everything, safe_import


def tpot_baseline(dataset_name: str, max_time_mins: int = 5, random_state: int = 42):
    """Fit a TPOT classifier on a dataset and report its validation score.

    Args:
        dataset_name: Name handed to ``load_dataset_safely``.
        max_time_mins: Time budget forwarded to ``TPOTClassifier``.
        random_state: Seed passed to ``seed_everything`` and to TPOT.

    Returns:
        A dict with two keys: ``"val_score"``, the fitted pipeline's score on
        the validation split as a float (``0.0`` if fitting or scoring raised),
        and ``"time_sec"``, the seconds elapsed from just after the dataset was
        loaded until scoring finished.

    Raises:
        RuntimeError: If TPOT cannot be imported or the dataset cannot be
            loaded.
    """
    seed_everything(random_state)
    tpot, err = safe_import("tpot")
    if tpot is None:
        raise RuntimeError(f"TPOT not installed: {err}")

    data, msg = load_dataset_safely(dataset_name)
    if data is None:
        raise RuntimeError(msg)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start = time.perf_counter()

    X_train, y_train = data["X_train"], data["y_train"]
    X_val, y_val = data["X_val"], data["y_val"]

    clf = tpot.TPOTClassifier(
        population_size=20,
        max_time_mins=max_time_mins,
        random_state=random_state,
        n_jobs=-1,
    )

    try:
        clf.fit(X_train, y_train)
        # use fitted sklearn pipeline instead of TPOTClassifier.score()
        score = clf.fitted_pipeline_.score(X_val, y_val)
    except Exception:
        # Best-effort baseline: keep returning 0.0 so a sweep over datasets
        # continues, but record the traceback so a failed run is not mistaken
        # for a genuine zero score.
        logging.getLogger(__name__).exception(
            "TPOT baseline failed on dataset %r; reporting val_score=0.0",
            dataset_name,
        )
        score = 0.0

    duration = time.perf_counter() - start
    return {
        "val_score": float(score),
        "time_sec": duration,
    }


if __name__ == "__main__":
    # Command-line entry point: run the TPOT baseline once and print the
    # resulting metrics dict.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="iris")
    parser.add_argument("--time_mins", type=int, default=5)
    cli_args = parser.parse_args()

    print(tpot_baseline(cli_args.dataset, cli_args.time_mins))
